"""
Analysis routines for kernel-to-metric simulation outputs.

This module provides helpers to compute radial averages, select radial
fit windows and fit power laws on log–log scales, estimate bootstrap
errors on fitted slopes, analyse Gauss-law plateaus, estimate amplitude
prefactors, and compute lensing slopes.  These routines are used by the
runner script to summarise potentials and fields into compact
statistics that can be compared against theoretical expectations.
"""

from __future__ import annotations

import numpy as np
from typing import Tuple, Sequence


def radial_profile(data: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute the mean of a 2-D field as a function of radius.

    The radial origin is the geometric centre of the array,
    ``((nx - 1) / 2, (ny - 1) / 2)``.  For the usual square ``L x L``
    lattice this is ``(L-1)/2`` along both axes.  Rectangular arrays are
    also accepted; each axis is centred independently.  Radial distances
    are measured in lattice units and binned by their integer part, so
    shell ``k`` collects the pixels with ``k <= r < k + 1``.

    Parameters
    ----------
    data : 2-D array_like
        Field sampled on the lattice.

    Returns
    -------
    radii : ndarray of float
        Shell indices ``0, 1, ..., max_shell`` as floats.  These are the
        inner edges of the shells, not their centres.
    radial_mean : ndarray of float
        Mean of ``data`` over each shell; ``0.0`` for shells with no pixels.
    counts : ndarray of int
        Number of pixels that fall in each shell.

    Raises
    ------
    ValueError
        If ``data`` is not two-dimensional.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(f"radial_profile expects a 2-D array, got ndim={data.ndim}")
    if data.size == 0:
        # Nothing to bin; avoid calling .max() on an empty index array.
        return (np.zeros(0, dtype=float),
                np.zeros(0, dtype=float),
                np.zeros(0, dtype=np.intp))

    nx, ny = data.shape
    # Centre each axis on its own midpoint; identical to (L-1)/2 for square input.
    coords_x = np.arange(nx, dtype=float) - (nx - 1) / 2.0
    coords_y = np.arange(ny, dtype=float) - (ny - 1) / 2.0
    x, y = np.meshgrid(coords_x, coords_y, indexing="ij")  # shape (nx, ny) == data.shape
    r_idx = np.floor(np.sqrt(x**2 + y**2)).astype(int).ravel()

    n_shells = r_idx.max() + 1
    sums = np.bincount(r_idx, weights=data.ravel(), minlength=n_shells)
    counts = np.bincount(r_idx, minlength=n_shells)
    radial_mean = np.zeros_like(sums)
    filled = counts > 0  # empty shells keep a mean of 0.0 rather than NaN
    radial_mean[filled] = sums[filled] / counts[filled]
    radii = np.arange(n_shells, dtype=float)
    return radii, radial_mean, counts


def fit_log_log(r: Sequence[float], f: Sequence[float]) -> Tuple[float, float, float]:
    """Fit ``log f = slope * log r + intercept`` by ordinary least squares.

    Only points with ``r > 0`` and ``f > 0`` are used, since the others
    have no real logarithm.  Natural logarithms are used, so
    ``exp(intercept)`` is the prefactor ``A`` of the power law
    ``f ~ A * r**slope``.

    Returns
    -------
    (slope, intercept, R2)
        ``R2`` is the coefficient of determination of the fit in log space.
        It is ``0.0`` when the log-values have zero variance.  If fewer than
        two distinct usable radii remain, the line is undetermined and
        ``(0.0, 0.0, 0.0)`` is returned.
    """
    r = np.asarray(r, dtype=float)
    f = np.asarray(f, dtype=float)
    usable = (r > 0) & (f > 0)
    x = np.log(r[usable])
    y = np.log(f[usable])
    # With a single distinct abscissa the design matrix is rank-deficient and
    # lstsq would return an arbitrary minimum-norm "slope"; treat as degenerate.
    if np.unique(x).size < 2:
        return 0.0, 0.0, 0.0
    A = np.vstack([x, np.ones_like(x)]).T
    slope, intercept = np.linalg.lstsq(A, y, rcond=None)[0]
    residual = y - (slope * x + intercept)
    ss_res = np.sum(residual ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    R2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
    return float(slope), float(intercept), float(R2)


def select_fit_window(radii: np.ndarray,
                      values: np.ndarray,
                      ell: int,
                      L: int,
                      min_bins: int = 5,
                      *,
                      r_min_factor: float = 3.0,
                      r_max_fraction: float = 0.25) -> Tuple[int, int, float, float, float]:
    """Automatically choose a radial fit window by maximising the log-log R².

    Window bounds are bin indices into ``radii``/``values``.  Every
    contiguous window ``[start, end]`` of at least ``min_bins`` bins that
    satisfies

      start >= r_min = max(ceil(r_min_factor * ell), 1)
      end   <= r_max = min(floor(r_max_fraction * L), len(radii) - 1)

    is fitted with ``fit_log_log``.  The window with the largest R² is
    kept; on ties, the first one encountered (smallest start, then
    smallest end) wins.  The defaults reproduce the constraints
    ``r_min >= 3*ell`` and ``r_max <= 0.25*L``.  The upper bound was
    tightened from 0.30*L to reduce edge artefacts.

    Parameters
    ----------
    radii, values : ndarray
        Radial bins and the profile to fit, indexed identically.
    ell : int
        Length scale setting the inner cutoff.
    L : int
        Lattice size setting the outer cutoff.
    min_bins : int
        Minimum number of bins in a candidate window.
    r_min_factor : float, keyword-only
        Inner cutoff in units of ``ell`` (default 3.0).
    r_max_fraction : float, keyword-only
        Outer cutoff as a fraction of ``L`` (default 0.25).

    Returns
    -------
    (start_idx, end_idx, slope, intercept, R2)
        If no window of ``min_bins`` bins fits between the bounds, no fit
        is performed.  The result is then
        ``(r_min, min(r_min + min_bins - 1, r_max), 0.0, 0.0, -inf)``, and
        ``end_idx`` may be smaller than ``start_idx``.  Callers should
        check for ``R2 == -inf``.
    """
    r_min_allowed = max(int(np.ceil(r_min_factor * ell)), 1)
    r_max_allowed = min(int(np.floor(r_max_fraction * L)), len(radii) - 1)

    # (start, end, slope, intercept, R2); the defaults double as the fallback.
    best = (r_min_allowed,
            min(r_min_allowed + min_bins - 1, r_max_allowed),
            0.0,
            0.0,
            -np.inf)

    # The last admissible start leaves exactly min_bins bins up to r_max_allowed.
    for start in range(r_min_allowed, r_max_allowed - min_bins + 2):
        for end in range(start + min_bins - 1, r_max_allowed + 1):
            m, c, R2 = fit_log_log(radii[start:end + 1], values[start:end + 1])
            if R2 > best[4]:  # strict '>' keeps the first window on ties
                best = (start, end, m, c, R2)
    return best


def bootstrap_slope_error(radii: np.ndarray,
                          values: np.ndarray,
                          start_idx: int,
                          end_idx: int,
                          n_resamples: int = 200,
                          seed: int | None = None) -> float:
    """Estimate the standard error of the log-log slope via bootstrap resampling.

    The bins ``start_idx..end_idx`` (inclusive) are resampled with
    replacement ``n_resamples`` times.  Each resample is refitted with
    ``fit_log_log``, and the sample standard deviation (``ddof=1``) of the
    resulting slopes is returned.

    A resample with fewer than two distinct usable radii (``r > 0`` and
    ``value > 0``) cannot determine a slope.  Such resamples are redrawn
    instead of contributing an arbitrary value to the spread.  If the
    window itself has fewer than two distinct usable radii, every resample
    would be degenerate and ``0.0`` is returned.  Otherwise, with
    ``n_resamples < 2``, the ``ddof=1`` standard deviation is undefined and
    NaN is returned.

    Parameters
    ----------
    seed : int or None
        Seed for ``np.random.default_rng``; ``None`` uses fresh entropy.
    """
    rng = np.random.default_rng(seed)
    radii = np.asarray(radii, dtype=float)
    values = np.asarray(values, dtype=float)
    indices = np.arange(start_idx, end_idx + 1)

    def _fittable(idx: np.ndarray) -> bool:
        # True when the selected bins contain >= 2 distinct usable radii.
        r_sel = radii[idx]
        ok = (r_sel > 0) & (values[idx] > 0)
        return np.unique(r_sel[ok]).size >= 2

    if not _fittable(indices):
        # Every resample would be degenerate (fit_log_log slope 0.0).
        return 0.0

    slopes = np.empty(n_resamples, dtype=float)
    for i in range(n_resamples):
        # Terminates almost surely: the full window is fittable, so each
        # draw has a positive probability of being fittable too.
        sample_idx = rng.choice(indices, size=indices.size, replace=True)
        while not _fittable(sample_idx):
            sample_idx = rng.choice(indices, size=indices.size, replace=True)
        m, _, _ = fit_log_log(radii[sample_idx], values[sample_idx])
        slopes[i] = m
    return float(np.std(slopes, ddof=1))


def gauss_law_plateau(radii: np.ndarray,
                      E_r: np.ndarray,
                      start_idx: int,
                      end_idx: int) -> Tuple[float, float]:
    """Summarise the effective enclosed charge over a radial window.

    ``Q_eff(r) = r**2 * E_r(r)`` is evaluated on the bins
    ``start_idx..end_idx`` (inclusive).  Returns ``(mean, std)`` of
    ``Q_eff``.  The spread is the sample standard deviation (``ddof=1``)
    and is reported as ``0.0`` for windows of one bin or fewer.  An empty
    window yields a NaN mean.
    """
    window = slice(start_idx, end_idx + 1)
    q_eff = radii[window] ** 2 * E_r[window]
    spread = float(q_eff.std(ddof=1)) if q_eff.size > 1 else 0.0
    return float(q_eff.mean()), spread


def amplitude_at_radius(V_radial: np.ndarray,
                        radii: np.ndarray,
                        start_idx: int,
                        end_idx: int) -> float:
    """Estimate the prefactor of a 1/r profile as ``V(r) * r``.

    The product is evaluated at the central bin of the inclusive window
    ``start_idx..end_idx``, i.e. at index ``(start_idx + end_idx) // 2``
    (the lower of the two middle bins for even-length windows).
    """
    centre = (start_idx + end_idx) // 2
    return float(V_radial[centre] * radii[centre])


def lensing_fit(b_vals: np.ndarray, alpha_vals: np.ndarray) -> Tuple[float, float]:
    """Fit the deflection law ``alpha(b) = slope / b + offset``.

    A straight line is fitted to ``alpha`` versus ``1/b`` by ordinary least
    squares.  The intercept (``offset``) is fitted as well but is not
    returned.  Only impact parameters ``b > 0`` are used.

    Parameters
    ----------
    b_vals, alpha_vals : array_like
        Impact parameters and the matching deflection angles, of equal length.

    Returns
    -------
    (slope, R2)
        ``R2`` is the coefficient of determination in the ``(1/b, alpha)``
        plane; it is ``0.0`` when ``alpha`` has zero variance.  When fewer
        than two distinct positive ``b`` values are available, the line is
        undetermined and ``(0.0, 0.0)`` is returned.
    """
    b_vals = np.asarray(b_vals, dtype=float)
    alpha_vals = np.asarray(alpha_vals, dtype=float)
    mask = b_vals > 0
    b = b_vals[mask]
    alpha = alpha_vals[mask]
    # A single distinct abscissa leaves the slope undetermined (lstsq would
    # return an arbitrary minimum-norm value); report the degenerate fit.
    if np.unique(b).size < 2:
        return 0.0, 0.0
    x = 1.0 / b
    A = np.vstack([x, np.ones_like(x)]).T
    m, c = np.linalg.lstsq(A, alpha, rcond=None)[0]
    residual = alpha - (m * x + c)
    ss_res = np.sum(residual ** 2)
    ss_tot = np.sum((alpha - alpha.mean()) ** 2)
    R2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
    return float(m), float(R2)
